{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e45f9b60-cd6b-4c15-958f-1feca5438128",
   "metadata": {},
   "source": [
    "# SQL Index Demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119eb42b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "\n",
    "# Route INFO-level logs to stdout through a single root handler.\n",
    "# basicConfig already installs a stdout StreamHandler, so attaching a second\n",
    "# one to the root logger would print every log record twice.\n",
    "# force=True replaces any pre-existing root handlers, which keeps this cell\n",
    "# idempotent when it is re-run (requires Python 3.8+).\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO, force=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "107396a9-4aa7-49b3-9f0f-a755726c19ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import SimpleDirectoryReader, WikipediaReader\n",
    "from IPython.display import Markdown, display"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77ac8d94-cd61-4869-a32b-0b2e7d18b83f",
   "metadata": {},
   "source": [
    "### Load Wikipedia Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93301fcf-a52b-430c-98a3-5360e6c8fc4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the wikipedia python package.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install wikipedia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ba3f7e5e-cdc4-4529-bba9-db45d8457dba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Wikipedia pages for the three demo cities.\n",
    "wiki_docs = WikipediaReader().load_data(\n",
    "    pages=['Toronto', 'Berlin', 'Tokyo'],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "461438c8-302d-45c5-8e69-16ad604686d1",
   "metadata": {},
   "source": [
    "### Create Database Schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a370b266-66f5-4624-bbf9-2ad57f0511f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import create_engine, MetaData, Table, Column, String, Integer, select, column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ea24f794-f10b-42e6-922d-9258b7167405",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In-memory SQLite database.\n",
    "engine = create_engine(\"sqlite:///:memory:\")\n",
    "# Keep MetaData unbound: the `bind` argument is deprecated in SQLAlchemy 1.4\n",
    "# and removed in 2.0, so the engine is passed explicitly to create_all().\n",
    "metadata_obj = MetaData()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b4154b29-7e23-4c26-a507-370a66186ae7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create city SQL table\n",
    "table_name = \"city_stats\"\n",
    "city_stats_table = Table(\n",
    "    table_name,\n",
    "    metadata_obj,\n",
    "    Column(\"city_name\", String(16), primary_key=True),\n",
    "    Column(\"population\", Integer),\n",
    "    Column(\"country\", String(16), nullable=False),\n",
    ")\n",
    "# Create the table on the engine explicitly, since the MetaData is unbound.\n",
    "metadata_obj.create_all(engine)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c09089a-6bcd-48db-8120-a84c8da3f82e",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Build Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "768d1581-b482-4c73-9963-5ffd68a2aafb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import GPTSQLStructStoreIndex, SQLDatabase, ServiceContext\n",
    "from langchain import OpenAI\n",
    "from llama_index import LLMPredictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bffabba0-8e54-4f24-ad14-2c8979c582a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLM used by the index: OpenAI text-davinci-002 with temperature 0.\n",
    "llm_predictor = LLMPredictor(\n",
    "    llm=OpenAI(temperature=0, model_name=\"text-davinci-002\"),\n",
    ")\n",
    "service_context = ServiceContext.from_defaults(\n",
    "    llm_predictor=llm_predictor,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9432787b-a8f0-4fc3-8323-e2cd9497df73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the engine, restricted to the city_stats table via include_tables.\n",
    "sql_database = SQLDatabase(\n",
    "    engine,\n",
    "    include_tables=[\"city_stats\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "84d4ee54-9f00-40fd-bab0-36e5e579dc9f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)).\""
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sql_database.table_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95043e10-6cdf-4f66-96bd-ce307ea7df3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: `table_name` names the SQL table that gets populated with\n",
    "# structured data extracted from the unstructured documents.\n",
    "index = GPTSQLStructStoreIndex.from_documents(\n",
    "    wiki_docs,\n",
    "    sql_database=sql_database,\n",
    "    table_name=\"city_stats\",\n",
    "    service_context=service_context,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b315b8ff-7dd7-4e7d-ac47-8c5a0c3e7ae9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('Toronto', 2731571, 'Canada'), ('Berlin', 600000, 'Germany'), ('Tokyo', 13929286, 'Japan')]\n"
     ]
    }
   ],
   "source": [
    "# View the current table contents.\n",
    "# Table.select() selects every column of city_stats (city_name, population,\n",
    "# country) and works on SQLAlchemy 1.x and 2.x alike, unlike the list form\n",
    "# `select([...])`, which was removed in SQLAlchemy 2.0.\n",
    "stmt = city_stats_table.select()\n",
    "\n",
    "with engine.connect() as connection:\n",
    "    results = connection.execute(stmt).fetchall()\n",
    "\n",
    "# fetchall() has already materialized the rows, so print after the connection closes\n",
    "print(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "051a171f-8c97-40ed-ae17-4e3fa3785487",
   "metadata": {},
   "source": [
    "### Query Index"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6a2303f-3bae-4fa2-8750-03f9af747848",
   "metadata": {},
   "source": [
    "We first show a raw SQL query, which is executed directly against the table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "eddd3608-31ff-4591-a02a-90987e312669",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> [query] Total LLM token usage: 0 tokens\n",
      "> [query] Total embedding token usage: 0 tokens\n"
     ]
    }
   ],
   "source": [
    "# mode=\"sql\" passes the SQL string straight through to the table\n",
    "response = index.query(\n",
    "    \"SELECT city_name from city_stats\",\n",
    "    mode=\"sql\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4253d4c3-f3e5-4779-bcd1-2e6e2818305f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<b>[('Berlin',), ('Tokyo',), ('Toronto',)]</b>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(Markdown(f\"<b>{response}</b>\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abff69de-80c4-4fe6-afa1-b3d7208a5c4c",
   "metadata": {},
   "source": [
    "Here we show a natural language query, which is translated to a SQL query under the hood."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d71045c0-7a96-4e86-b38c-c378b7759aa4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Predicted SQL query: SELECT city_name, population\n",
      "FROM city_stats\n",
      "ORDER BY population DESC\n",
      "LIMIT 1\n",
      "> [query] Total LLM token usage: 144 tokens\n",
      "> [query] Total embedding token usage: 0 tokens\n"
     ]
    }
   ],
   "source": [
    "# Set the logging level to DEBUG for more detailed output.\n",
    "response = index.query(\n",
    "    \"Which city has the highest population?\",\n",
    "    mode=\"default\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "e713d73e-73ed-4748-8673-f476899fac8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<b>[('Tokyo', 13929286)]</b>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(Markdown(f\"<b>{response}</b>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "54a99cb0-578a-40ec-a3eb-1666ac18fbed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Tokyo', 13929286)]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# You can also fetch the raw result from SQLAlchemy!\n",
    "response.extra_info[\"result\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a567566-9828-4aa4-afde-b253c01eda69",
   "metadata": {},
   "source": [
    "### Using Langchain for Querying\n",
    "\n",
    "Since our `SQLDatabase` class inherits from langchain, you can also use langchain itself to run queries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "8e0acde4-ca61-42e9-97f8-c9cf11502157",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only SQLDatabaseChain is new here. OpenAI is already imported above, and\n",
    "# importing langchain's SQLDatabase would shadow the llama_index SQLDatabase\n",
    "# imported earlier, so re-running earlier cells would silently use a\n",
    "# different class.\n",
    "from langchain import SQLDatabaseChain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "30860e8b-9ad0-418c-b266-753242c1f208",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OpenAI(temperature=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "07068a3a-30a4-4473-ba82-ab6e93e3437c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the logging level to DEBUG for more detailed output.\n",
    "db_chain = SQLDatabaseChain(\n",
    "    llm=llm,\n",
    "    database=sql_database,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a04c0a1d-f6a8-4a4a-9181-4123b09ec614",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new SQLDatabaseChain chain...\u001b[0m\n",
      "Which city has the highest population? \n",
      "SQLQuery:\u001b[32;1m\u001b[1;3m SELECT city_name FROM city_stats ORDER BY population DESC LIMIT 1;\u001b[0m\n",
      "SQLResult: \u001b[33;1m\u001b[1;3m[('Tokyo',)]\u001b[0m\n",
      "Answer:\u001b[32;1m\u001b[1;3m Tokyo has the highest population.\u001b[0m\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "' Tokyo has the highest population.'"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "db_chain.run(\"Which city has the highest population?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eeb4af4-d18e-4cf4-b1cf-1347420e24fb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
